import numpy as np
from scipy.optimize import leastsq

from fitting import baseline

def lorentzian(x0, g, a):
    """Build a Lorentzian line-shape function.

    The returned callable evaluates a Lorentzian centred at ``x0`` with half
    width at half maximum ``g``. ``a`` scales the curve so that its integral
    over x equals ``a``; the peak height is therefore ``a / (pi * g)``.
    """
    def profile(x):
        # same operation order as a / pi * (g / ((x - x0)^2 + g^2))
        return a / np.pi * (g / ((x - x0)**2 + g**2))
    return profile

def n_lorentzians(*p):
    """Return a callable evaluating a constant offset plus a sum of Lorentzians.

    ``p`` is laid out as (c, x0, g0, a0, x1, g1, a1, ...): the offset c
    followed by one (center, HWHM, area) triple per Lorentzian, i.e. the same
    layout fit_odmr returns. Trailing values that do not complete a triple
    are ignored.

    The returned function accepts an array, a list or a scalar ``x`` and
    returns the model evaluated at ``x`` as a float array.
    """
    # '//' keeps the count an int on Python 3 as well ('/' would yield a
    # float that range() rejects); on Python 2 the result is unchanged.
    n_peaks = (len(p) - 1) // 3

    def f(x):
        x = np.asarray(x)  # also accept lists and scalars
        y = p[0] * np.ones(x.shape)
        for k in range(n_peaks):
            y += lorentzian(*p[3 * k + 1:3 * k + 4])(x)
        return y
    return f

def grow(mask):
    """Dilate every True run of a 1D boolean array by one sample on each side.

    Used by fit_odmr to merge fragmented, noisy regions above threshold.
    """
    from_right = np.append(mask[1:], False)   # element i holds mask[i+1]
    from_left = np.append(False, mask[:-1])   # element i holds mask[i-1]
    neighbours = np.logical_or(from_left, from_right)
    return np.logical_or(mask, neighbours)

def _edge_indices(mask):
    """Return the indices where a 1D boolean mask changes value.

    Each True run contributes its first index and the index just past its
    end; a run starting at index 0 yields an edge at 0, and a run reaching
    the last sample has no closing edge.
    """
    return np.where(np.logical_xor(mask, np.append(False, mask[:-1])))[0]


def _check_edges(edges):
    """Raise RuntimeError unless edges pair up into at least one closed region."""
    if len(edges) % 2 != 0:
        raise RuntimeError('found an uneven number of edges')
    if len(edges) < 2:
        raise RuntimeError('did not find a distinct peak with the given threshold')


def fit_odmr(x, y, threshold=0.5, number_of_lorentzians='auto'):
    """Fit a constant baseline plus a sum of Lorentzians to (x, y).

    The data are shifted by ``baseline(y)``, flipped if ``threshold`` is
    negative (so dips are fitted like peaks), and scaled so that the largest
    deviation from the baseline becomes 1. Every region where the scaled data
    exceed ``abs(threshold)`` seeds one Lorentzian.

    Parameters
    ----------
    x, y : array_like
        Abscissa and data, same length.
    threshold : float
        Detection level relative to the scaled data. A negative value fits
        dips instead of peaks.
    number_of_lorentzians : 'auto' or int
        With 'auto', regions are grown (merging noise fragments) for as long
        as that reduces their number. With an int N >= 1, regions are grown
        until at most N remain.

    Returns
    -------
    p : ndarray
        (c, x0, g0, a0, x1, g1, a1, ...) in the units of the original data:
        baseline c and, per Lorentzian, center x0, HWHM g and area a.
    delta : ndarray
        One-sigma uncertainties of p, from the fit covariance scaled by the
        residual variance. NaN where the covariance is unavailable.

    Raises
    ------
    RuntimeError
        If no well-defined peak region is found or the fit does not converge.
    ValueError
        If number_of_lorentzians is a number smaller than 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # re-scale the data to the range (0, 1] with the baseline at 0;
    # flip in y-direction if threshold is negative
    y0 = baseline(y)
    yp = y - y0
    if threshold < 0:
        yp = -yp
    y_max = yp.max()
    if not y_max > 0:
        # flat data (or nothing above the baseline): scaling would divide by <= 0
        raise RuntimeError('did not find a distinct peak with the given threshold')
    yp = yp / y_max

    # crossings through a horizontal line at height abs(threshold)
    mask = yp > abs(threshold)
    edges = _edge_indices(mask)
    _check_edges(edges)

    if number_of_lorentzians == 'auto':
        # grow connected regions while that still merges fragments (kills noise)
        while True:
            mask = grow(mask)
            new_edges = _edge_indices(mask)
            if len(new_edges) < len(edges):
                edges = new_edges
            else:
                break
    else:
        if number_of_lorentzians < 1:
            # growing could never reach fewer than one region -> endless loop
            raise ValueError("number_of_lorentzians must be >= 1 or 'auto'")
        # grow until the number of regions is <= number_of_lorentzians
        while len(edges) > 2 * number_of_lorentzians:
            mask = grow(mask)
            edges = _edge_indices(mask)
    _check_edges(edges)

    n_peaks = len(edges) // 2
    # initial guess in the scaled/flipped frame the fit actually works in
    p = [0.0]
    for left, right in edges.reshape((n_peaks, 2)):
        # the region width approximates the full width; lorentzian() takes the HWHM
        g = 0.5 * abs(x[right] - x[left])
        i = yp[left:right].argmax() + left  # extremum of the region (a peak in yp)
        x0 = x[i]
        a = yp[i] * np.pi * g  # peak height h = a / (pi * g)  =>  a = h * pi * g
        p += [x0, g, a]
    p = tuple(p)

    # residuals of N Lorentzians with a common baseline against the scaled data
    def chi(q):
        r = q[0] - yp
        for k in range(n_peaks):
            r = r + lorentzian(*q[3 * k + 1:3 * k + 4])(x)
        return r

    p_opt, cov_x, _info, _mesg, ier = leastsq(chi, p, full_output=True)
    # per scipy docs, only ier in 1..4 means a solution was found
    if ier not in (1, 2, 3, 4):
        raise RuntimeError('least square fit did not work out')

    # cov_x must be scaled by the residual variance to give the parameter
    # covariance; it is None when the problem is singular
    dof = len(yp) - len(p)
    if cov_x is None or dof <= 0:
        delta = np.full(len(p), np.nan)
    else:
        residual_variance = (chi(p_opt) ** 2).sum() / dof
        delta = np.sqrt(np.diag(cov_x) * residual_variance)

    p = np.array(p_opt, dtype=float)

    # rescale fit parameters back to original data: y = y0 + y_amp * yp
    y_amp = -y_max if threshold < 0 else y_max
    p[0] = p[0] * y_amp + y0
    p[3::3] *= y_amp

    # uncertainties scale with the magnitude only
    delta[0] *= y_max
    delta[3::3] *= y_max

    return p, delta

if __name__ == '__main__':
    # Manual check: fit one stored ODMR spectrum and plot the data with the fit.
    import sys

    try:
        import cPickle as pickle  # Python 2
    except ImportError:
        import pickle  # Python 3

    import matplotlib.pyplot as plt

    # Data file used during development; pass another path as the first
    # command-line argument to fit a different spectrum.
    default_path = '/home/helmut/projects/NewDefect/nuclear_polarization/alignement_at_LAC/2012-09-10/zeeman/D-E/2012-09-10_0306-05_DmE_zeeman_0.122A.pys'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as fil:
        d = pickle.load(fil)
    x = d['frequency']
    y = d['counts']

    y_max = y.max()
    y_min = y.min()

    # Relative detection threshold; presumably ~8x the Poisson (sqrt-count)
    # noise at the lowest count level, as a fraction of the data range -- confirm.
    threshold = 8 * y_min**0.5 / (y_max - y_min)

    print(threshold)

    p, dp = fit_odmr(x, y, threshold=threshold)

    plt.close('all')
    plt.plot(x, y)
    plt.plot(x, n_lorentzians(*p)(x), 'r-')
    plt.show()
